{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up rendering dependencies for Google Colaboratory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel; -q keeps the\n",
    "# output short while still showing errors if the install fails\n",
    "%pip install -q gym pyvirtualdisplay\n",
    "# only stdout is silenced, so apt errors stay visible\n",
    "!apt-get install -y -qq xvfb python-opengl ffmpeg > /dev/null"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install d3rlpy!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this notebook uses the d3rlpy 1.x API (d3rlpy.online, d3rlpy.metrics.scorer,\n",
    "# use_gpu), which is not available in d3rlpy 2.x\n",
    "%pip install -q \"d3rlpy<2\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the CartPole environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "\n",
    "ENV_ID = 'CartPole-v0'\n",
    "\n",
    "# two independent instances of the same task\n",
    "env = gym.make(ENV_ID)\n",
    "eval_env = gym.make(ENV_ID)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the data-driven deep reinforcement learning algorithm."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from d3rlpy.algos import DQN\n",
    "from d3rlpy.online.explorers import ConstantEpsilonGreedy\n",
    "from d3rlpy.online.buffers import ReplayBuffer\n",
    "from d3rlpy.metrics.scorer import evaluate_on_environment\n",
    "\n",
    "# DQN learner, running without a GPU\n",
    "dqn = DQN(\n",
    "    learning_rate=1e-3,\n",
    "    target_update_interval=100,\n",
    "    use_gpu=False,\n",
    ")\n",
    "\n",
    "# epsilon-greedy explorer with a fixed epsilon\n",
    "explorer = ConstantEpsilonGreedy(epsilon=0.3)\n",
    "\n",
    "# replay buffer bound to the training environment\n",
    "buffer = ReplayBuffer(maxlen=50_000, env=env)\n",
    "\n",
    "# online training on `env`, with `eval_env` passed for evaluation\n",
    "dqn.fit_online(\n",
    "    env,\n",
    "    buffer,\n",
    "    explorer,\n",
    "    eval_env=eval_env,\n",
    "    n_steps=50_000,\n",
    "    n_steps_per_epoch=10_000,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up rendering utilities for Google Colaboratory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import io\n",
    "import base64\n",
    "\n",
    "from gym.wrappers import Monitor\n",
    "from IPython.display import HTML\n",
    "from IPython import display as ipythondisplay\n",
    "from pyvirtualdisplay import Display\n",
    "\n",
    "# start virtual display\n",
    "display = Display(visible=0, size=(1400, 900))\n",
    "display.start()\n",
    "\n",
    "# play recorded video\n",
    "def show_video():\n",
    "    \"\"\"Embed one recorded .mp4 from ./video as an inline HTML5 player.\n",
    "\n",
    "    The alphabetically first match of 'video/*.mp4' is shown. If there is\n",
    "    no match, a short notice is printed instead.\n",
    "    \"\"\"\n",
    "    # sort so the same file is picked on every run (glob order is arbitrary)\n",
    "    mp4list = sorted(glob.glob('video/*.mp4'))\n",
    "    if not mp4list:\n",
    "        print(\"Could not find video\")\n",
    "        return\n",
    "\n",
    "    # read-only binary mode; the context manager closes the file handle\n",
    "    with io.open(mp4list[0], 'rb') as video_file:\n",
    "        encoded = base64.b64encode(video_file.read()).decode('ascii')\n",
    "\n",
    "    ipythondisplay.display(HTML(data='''\n",
    "            <video alt=\"test\" autoplay loop controls style=\"height: 400px;\">\n",
    "                <source src=\"data:video/mp4;base64,{0}\" type=\"video/mp4\" />\n",
    "            </video>'''.format(encoded)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Record video!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# record on a fresh environment under its own name, so re-running this cell\n",
    "# never wraps an already-wrapped `env` in another Monitor\n",
    "record_env = Monitor(gym.make('CartPole-v0'), './video', force=True)\n",
    "\n",
    "# evaluate\n",
    "score = evaluate_on_environment(record_env)(dqn)\n",
    "\n",
    "# close the wrapped environment once evaluation is done\n",
    "record_env.close()\n",
    "\n",
    "# keep the evaluation result as the cell output\n",
    "score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's see how it works!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_video()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
